{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bayesnet_intro_md"
      },
      "source": [
        "# BayesNet"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "license_cell",
      "metadata": {
        "tags": [
          "remove-cell"
        ]
      },
      "source": [
        "GTSAM Copyright 2010-2022, Georgia Tech Research Corporation,\nAtlanta, Georgia 30332-0415\nAll Rights Reserved\n\nAuthors: Frank Dellaert, et al. (see THANKS for the full author list)\n\nSee LICENSE for the license information"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bayesnet_desc_md"
      },
      "source": [
        "A `BayesNet` in GTSAM represents a directed graphical model, created by running sequential variable elimination (like Cholesky or QR factorization) on a `FactorGraph` or constructing from scratch.\n",
        "\n",
        "It is essentially a collection of `Conditional` objects, ordered according to the elimination order. Each conditional represents $P(\\text{variable} | \\text{parents})$, where the parents are variables that appear later in the elimination ordering.\n",
        "\n",
        "A Bayes net represents the joint probability distribution as a product of conditional probabilities stored in the net:\n",
        "\n",
        "$$\n",
        "P(X_1, X_2, \\dots, X_N) = \\prod_{i=1}^N P(X_i | \\text{Parents}(X_i))\n",
        "$$\n",
        "The total log-probability of an assignment is the sum of the log-probabilities of its conditionals:\n",
        "$$\n",
        "\\log P(X_1, \\dots, X_N) = \\sum_{i=1}^N \\log P(X_i | \\text{Parents}(X_i))\n",
        "$$\n",
        "\n",
        "Like `FactorGraph`, `BayesNet` is templated on the type of conditional it stores (e.g., `GaussianBayesNet`, `DiscreteBayesNet`)."
      ]
    },
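    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bayesnet_chain_example_md"
      },
      "source": [
        "As a small worked instance of the factorization above, consider three variables eliminated in the order $X_1, X_2, X_3$, and assume (purely for illustration) that elimination leaves each variable conditioned only on the next one. The joint distribution and its log would then read:\n",
        "\n",
        "$$\n",
        "P(X_1, X_2, X_3) = P(X_1 | X_2)\\, P(X_2 | X_3)\\, P(X_3)\n",
        "$$\n",
        "\n",
        "$$\n",
        "\\log P(X_1, X_2, X_3) = \\log P(X_1 | X_2) + \\log P(X_2 | X_3) + \\log P(X_3)\n",
        "$$\n",
        "\n",
        "Here $X_3$, eliminated last, has no parents."
      ]
    },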
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bayesnet_colab_md"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/borglab/gtsam/blob/develop/gtsam/inference/doc/BayesNet.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bayesnet_pip_code",
        "tags": [
          "remove-cell"
        ]
      },
      "outputs": [],
      "source": [
        "%pip install --quiet gtsam-develop"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 44,
      "metadata": {
        "id": "bayesnet_import_code"
      },
      "outputs": [],
      "source": [
        "import gtsam\n",
        "import numpy as np\n",
        "import graphviz\n",
        "\n",
        "# We need concrete graph types and elimination to get a BayesNet\n",
        "from gtsam import GaussianFactorGraph, Ordering, GaussianBayesNet\n",
        "# For the Asia example\n",
        "from gtsam import DiscreteBayesNet, DiscreteConditional, DiscreteKeys, DiscreteValues, symbol\n",
        "from gtsam import symbol_shorthand\n",
        "\n",
        "X = symbol_shorthand.X\n",
        "L = symbol_shorthand.L"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bayesnet_create_md"
      },
      "source": [
        "## Creating a BayesNet (via Elimination)\n",
        "\n",
        "BayesNets are typically obtained by eliminating a `FactorGraph`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 45,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "bayesnet_eliminate_code",
        "outputId": "01234567-89ab-cdef-0123-456789abcdef"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Original Factor Graph:\n",
            "\n",
            "size: 3\n",
            "factor 0: \n",
            "  A[x0] = [\n",
            "\t-1\n",
            "]\n",
            "  b = [ 0 ]\n",
            "  Noise model: unit (1) \n",
            "factor 1: \n",
            "  A[x0] = [\n",
            "\t-1\n",
            "]\n",
            "  A[x1] = [\n",
            "\t1\n",
            "]\n",
            "  b = [ 0 ]\n",
            "  Noise model: unit (1) \n",
            "factor 2: \n",
            "  A[x1] = [\n",
            "\t-1\n",
            "]\n",
            "  A[x2] = [\n",
            "\t1\n",
            "]\n",
            "  b = [ 0 ]\n",
            "  Noise model: unit (1) \n"
          ]
        }
      ],
      "source": [
        "# Create a simple Gaussian Factor Graph P(x0) P(x1|x0) P(x2|x1)\n",
        "graph = GaussianFactorGraph()\n",
        "model = gtsam.noiseModel.Isotropic.Sigma(1, 1.0)\n",
        "graph.add(X(0), -np.eye(1), np.zeros(1), model)\n",
        "graph.add(X(0), -np.eye(1), X(1), np.eye(1), np.zeros(1), model)\n",
        "graph.add(X(1), -np.eye(1), X(2), np.eye(1), np.zeros(1), model)\n",
        "print(\"Original Factor Graph:\")\n",
        "graph.print()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 46,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "Resulting BayesNet:\n",
            "\n",
            "size: 3\n",
            "conditional 0:  p(x0 | x1)\n",
            "  R = [ 1.41421 ]\n",
            "  S[x1] = [ -0.707107 ]\n",
            "  d = [ 0 ]\n",
            "  logNormalizationConstant: -0.572365\n",
            "  No noise model\n",
            "conditional 1:  p(x1 | x2)\n",
            "  R = [ 1.22474 ]\n",
            "  S[x2] = [ -0.816497 ]\n",
            "  d = [ 0 ]\n",
            "  logNormalizationConstant: -0.716206\n",
            "  No noise model\n",
            "conditional 2:  p(x2)\n",
            "  R = [ 0.57735 ]\n",
            "  d = [ 0 ]\n",
            "  mean: 1 elements\n",
            "  x2: 0\n",
            "  logNormalizationConstant: -1.46824\n",
            "  No noise model\n"
          ]
        }
      ],
      "source": [
        "# Eliminate sequentially using a specific ordering\n",
        "ordering = Ordering([X(0), X(1), X(2)])\n",
        "bayes_net = graph.eliminateSequential(ordering)\n",
        "\n",
        "print(\"\\nResulting BayesNet:\")\n",
        "bayes_net.print()"
      ]
    },
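    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bayesnet_elim_sketch_md"
      },
      "source": [
        "The `R` and `S` values printed above can be reproduced by hand. The following is a minimal sketch of scalar elimination for this particular chain, written in plain Python. It is not how GTSAM implements elimination, and the names `carried` and `results` are only illustrative. It assumes every factor has unit noise and a zero right-hand side, as in the graph above."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bayesnet_elim_sketch_code"
      },
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "# Information on x0 contributed by the unary factor (unit noise, A = -1).\n",
        "carried = 1.0\n",
        "results = []  # (variable name, R, S) per conditional; S is None for the last one\n",
        "\n",
        "for name in [\"x0\", \"x1\"]:\n",
        "    # The binary factor to the next variable adds 1 to this variable's\n",
        "    # information and couples it to the next variable with -1.\n",
        "    diag = carried + 1.0\n",
        "    R = math.sqrt(diag)\n",
        "    S = -1.0 / R\n",
        "    results.append((name, R, S))\n",
        "    # Information left on the next variable: its own 1 minus what was absorbed.\n",
        "    carried = 1.0 - S * S\n",
        "\n",
        "# The last variable has no parents, so only R remains.\n",
        "results.append((\"x2\", math.sqrt(carried), None))\n",
        "\n",
        "for name, R, S in results:\n",
        "    print(name, round(R, 5), None if S is None else round(S, 6))"
      ]
    },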
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bayesnet_props_md"
      },
      "source": [
        "## Properties and Access\n",
        "\n",
        "A `BayesNet` provides access to its constituent conditionals and basic properties."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 47,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "bayesnet_access_code",
        "outputId": "12345678-9abc-def0-1234-56789abcdef0"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "BayesNet size: 3\n",
            "Conditional at index 1: \n",
            "GaussianConditional p(x1 | x2)\n",
            "  R = [ 1.22474 ]\n",
            "  S[x2] = [ -0.816497 ]\n",
            "  d = [ 0 ]\n",
            "  logNormalizationConstant: -0.716206\n",
            "  No noise model\n",
            "Keys in BayesNet: x0x1x2\n"
          ]
        }
      ],
      "source": [
        "print(f\"BayesNet size: {bayes_net.size()}\")\n",
        "\n",
        "# Access conditional by index\n",
        "conditional1 = bayes_net.at(1)\n",
        "print(\"Conditional at index 1: \")\n",
        "conditional1.print()\n",
        "\n",
        "# Get all keys involved\n",
        "bn_keys = bayes_net.keys()\n",
        "print(f\"Keys in BayesNet: {bn_keys}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bayesnet_eval_md"
      },
      "source": [
        "## Evaluation and Solution\n",
        "\n",
        "The `logProbability(Values)` method computes the log probability of a variable assignment given the conditional distributions in the Bayes net. For Gaussian Bayes nets, the `optimize()` method can be used to find the maximum likelihood estimate (MLE) solution via back-substitution."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 48,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "bayesnet_eval_code",
        "outputId": "23456789-abcd-ef01-2345-6789abcdef01"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Log Probability at 0,0,0]: -2.7568155996140185\n",
            "Optimized Solution (MLE):\n",
            "VectorValues: 3 elements\n",
            "  x0: 0\n",
            "  x1: 0\n",
            "  x2: 0\n"
          ]
        }
      ],
      "source": [
        "# For GaussianBayesNet, we use VectorValues\n",
        "mle_solution = bayes_net.optimize()\n",
        "\n",
        "# Calculate log probability (requires providing values for all variables)\n",
        "log_prob = bayes_net.logProbability(mle_solution)\n",
        "print(f\"Log Probability at {mle_solution.at(X(0))[0]:.0f},{mle_solution.at(X(1))[0]:.0f},{mle_solution.at(X(2))[0]:.0f}]: {log_prob}\")\n",
        "\n",
        "print(\"Optimized Solution (MLE):\")\n",
        "mle_solution.print()"
      ]
    },
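    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bayesnet_logprob_check_md"
      },
      "source": [
        "As a sanity check on the value above: at the solution every conditional's error term is zero, so the log-probability reduces to the sum of the per-conditional log-normalization constants. The sketch below recomputes that sum in plain Python, assuming the scalar Gaussian form $\\log|R| - \\tfrac{1}{2}\\log(2\\pi)$ for each constant, which is consistent with the `logNormalizationConstant` values printed earlier. The helper name is hypothetical."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bayesnet_logprob_check_code"
      },
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "# R entries of the three scalar conditionals (sqrt(2), sqrt(3/2), sqrt(1/3)),\n",
        "# matching the printed 1.41421, 1.22474 and 0.57735.\n",
        "R_values = [math.sqrt(2.0), math.sqrt(1.5), math.sqrt(1.0 / 3.0)]\n",
        "\n",
        "def log_normalization_constant(r):\n",
        "    # Hypothetical helper: log of |r| / sqrt(2*pi) for a 1-D conditional.\n",
        "    return math.log(abs(r)) - 0.5 * math.log(2.0 * math.pi)\n",
        "\n",
        "constants = [log_normalization_constant(r) for r in R_values]\n",
        "log_prob_at_mean = sum(constants)\n",
        "\n",
        "print([round(c, 6) for c in constants])\n",
        "print(log_prob_at_mean)"
      ]
    },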
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bayesnet_viz_md"
      },
      "source": [
        "## Visualization\n",
        "\n",
        "Bayes nets can also be visualized using Graphviz."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 49,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "bayesnet_dot_code",
        "outputId": "3456789a-bcde-f012-3456-789abcdef012"
      },
      "outputs": [
        {
          "data": {
            "image/svg+xml": [
              "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
              "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
              " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
              "<!-- Generated by graphviz version 2.50.0 (0)\n",
              " -->\n",
              "<!-- Pages: 1 -->\n",
              "<svg width=\"62pt\" height=\"188pt\"\n",
              " viewBox=\"0.00 0.00 62.00 188.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
              "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 184)\">\n",
              "<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-184 58,-184 58,4 -4,4\"/>\n",
              "<!-- var8646911284551352320 -->\n",
              "<g id=\"node1\" class=\"node\">\n",
              "<title>var8646911284551352320</title>\n",
              "<ellipse fill=\"none\" stroke=\"black\" cx=\"27\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n",
              "<text text-anchor=\"middle\" x=\"27\" y=\"-14.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">x0</text>\n",
              "</g>\n",
              "<!-- var8646911284551352321 -->\n",
              "<g id=\"node2\" class=\"node\">\n",
              "<title>var8646911284551352321</title>\n",
              "<ellipse fill=\"none\" stroke=\"black\" cx=\"27\" cy=\"-90\" rx=\"27\" ry=\"18\"/>\n",
              "<text text-anchor=\"middle\" x=\"27\" y=\"-86.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">x1</text>\n",
              "</g>\n",
              "<!-- var8646911284551352321&#45;&gt;var8646911284551352320 -->\n",
              "<g id=\"edge2\" class=\"edge\">\n",
              "<title>var8646911284551352321&#45;&gt;var8646911284551352320</title>\n",
              "<path fill=\"none\" stroke=\"black\" d=\"M27,-71.7C27,-63.98 27,-54.71 27,-46.11\"/>\n",
              "<polygon fill=\"black\" stroke=\"black\" points=\"30.5,-46.1 27,-36.1 23.5,-46.1 30.5,-46.1\"/>\n",
              "</g>\n",
              "<!-- var8646911284551352322 -->\n",
              "<g id=\"node3\" class=\"node\">\n",
              "<title>var8646911284551352322</title>\n",
              "<ellipse fill=\"none\" stroke=\"black\" cx=\"27\" cy=\"-162\" rx=\"27\" ry=\"18\"/>\n",
              "<text text-anchor=\"middle\" x=\"27\" y=\"-158.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">x2</text>\n",
              "</g>\n",
              "<!-- var8646911284551352322&#45;&gt;var8646911284551352321 -->\n",
              "<g id=\"edge1\" class=\"edge\">\n",
              "<title>var8646911284551352322&#45;&gt;var8646911284551352321</title>\n",
              "<path fill=\"none\" stroke=\"black\" d=\"M27,-143.7C27,-135.98 27,-126.71 27,-118.11\"/>\n",
              "<polygon fill=\"black\" stroke=\"black\" points=\"30.5,-118.1 27,-108.1 23.5,-118.1 30.5,-118.1\"/>\n",
              "</g>\n",
              "</g>\n",
              "</svg>\n"
            ],
            "text/plain": [
              "<graphviz.sources.Source at 0x18b7757b020>"
            ]
          },
          "execution_count": 49,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "graphviz.Source(bayes_net.dot())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bayesnet_discrete_md"
      },
      "source": [
        "## Example: DiscreteBayesNet (Asia Network)\n",
        "\n",
        "While the previous examples focused on `GaussianBayesNet`, GTSAM also supports `DiscreteBayesNet` for representing probability distributions over discrete variables. Here we construct the classic 'Asia' network example directly by adding `DiscreteConditional` objects."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 50,
      "metadata": {
        "id": "bayesnet_discrete_imports_code"
      },
      "outputs": [],
      "source": [
        "# Define keys for the Asia network variables\n",
        "A = symbol('A', 8) # Visit to Asia?\n",
        "S = symbol('S', 7) # Smoker?\n",
        "T = symbol('T', 6) # Tuberculosis?\n",
        "L = symbol('L', 5) # Lung Cancer?\n",
        "B = symbol('B', 4) # Bronchitis?\n",
        "E = symbol('E', 3) # Tuberculosis or Lung Cancer?\n",
        "X = symbol('X', 2) # Positive X-Ray?\n",
        "D = symbol('D', 1) # Dyspnea (Shortness of breath)?\n",
        "\n",
        "# Define cardinalities (all are binary in this case)\n",
        "cardinalities = { A: 2, S: 2, T: 2, L: 2, B: 2, E: 2, X: 2, D: 2 }\n",
        "\n",
        "# Helper to create DiscreteKeys object\n",
        "def make_keys(keys_list):\n",
        "    dk = DiscreteKeys()\n",
        "    for k in keys_list:\n",
        "        dk.push_back((k, cardinalities[k]))\n",
        "    return dk"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 51,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "bayesnet_discrete_build_code",
        "outputId": "456789ab-cdef-0123-4567-89abcdef0123"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Asia Bayes Net:\n",
            "DiscreteBayesNet\n",
            " \n",
            "size: 8\n",
            "conditional 0:  P( D1 | E3 B4 ):\n",
            " Choice(E3) \n",
            " 0 Choice(D1) \n",
            " 0 0 Choice(B4) \n",
            " 0 0 0 Leaf  0.9\n",
            " 0 0 1 Leaf  0.2\n",
            " 0 1 Choice(B4) \n",
            " 0 1 0 Leaf  0.1\n",
            " 0 1 1 Leaf  0.8\n",
            " 1 Choice(D1) \n",
            " 1 0 Choice(B4) \n",
            " 1 0 0 Leaf  0.3\n",
            " 1 0 1 Leaf  0.1\n",
            " 1 1 Choice(B4) \n",
            " 1 1 0 Leaf  0.7\n",
            " 1 1 1 Leaf  0.9\n",
            "\n",
            "conditional 1:  P( X2 | E3 ):\n",
            " Choice(X2) \n",
            " 0 Choice(E3) \n",
            " 0 0 Leaf 0.95\n",
            " 0 1 Leaf 0.02\n",
            " 1 Choice(E3) \n",
            " 1 0 Leaf 0.05\n",
            " 1 1 Leaf 0.98\n",
            "\n",
            "conditional 2:  P( E3 | T6 L5 ):\n",
            " Choice(T6) \n",
            " 0 Choice(L5) \n",
            " 0 0 Choice(E3) \n",
            " 0 0 0 Leaf    1\n",
            " 0 0 1 Leaf    0\n",
            " 0 1 Choice(E3) \n",
            " 0 1 0 Leaf    0\n",
            " 0 1 1 Leaf    1\n",
            " 1 Choice(L5) \n",
            " 1 0 Choice(E3) \n",
            " 1 0 0 Leaf    0\n",
            " 1 0 1 Leaf    1\n",
            " 1 1 Choice(E3) \n",
            " 1 1 0 Leaf    0\n",
            " 1 1 1 Leaf    1\n",
            "\n",
            "conditional 3:  P( B4 | S7 ):\n",
            " Choice(S7) \n",
            " 0 Choice(B4) \n",
            " 0 0 Leaf  0.7\n",
            " 0 1 Leaf  0.3\n",
            " 1 Choice(B4) \n",
            " 1 0 Leaf  0.4\n",
            " 1 1 Leaf  0.6\n",
            "\n",
            "conditional 4:  P( L5 | S7 ):\n",
            " Choice(S7) \n",
            " 0 Choice(L5) \n",
            " 0 0 Leaf 0.99\n",
            " 0 1 Leaf 0.01\n",
            " 1 Choice(L5) \n",
            " 1 0 Leaf  0.9\n",
            " 1 1 Leaf  0.1\n",
            "\n",
            "conditional 5:  P( T6 | A8 ):\n",
            " Choice(T6) \n",
            " 0 Choice(A8) \n",
            " 0 0 Leaf 0.99\n",
            " 0 1 Leaf 0.95\n",
            " 1 Choice(A8) \n",
            " 1 0 Leaf 0.01\n",
            " 1 1 Leaf 0.05\n",
            "\n",
            "conditional 6:  P( S7 ):\n",
            " Leaf  0.5\n",
            "\n",
            "conditional 7:  P( A8 ):\n",
            " Choice(A8) \n",
            " 0 Leaf 0.99\n",
            " 1 Leaf 0.01\n",
            "\n"
          ]
        }
      ],
      "source": [
        "# Create the DiscreteBayesNet\n",
        "asia_net = DiscreteBayesNet()\n",
        "\n",
        "# Helper function to create parent list in correct format\n",
        "def make_parent_tuples(parent_keys):\n",
        "    return [(pk, cardinalities[pk]) for pk in parent_keys]\n",
        "\n",
        "# P(D | E, B) - Dyspnea given Either and Bronchitis\n",
        "asia_net.add(DiscreteConditional((D, cardinalities[D]), make_parent_tuples([E, B]), \"9/1 2/8 3/7 1/9\"))\n",
        "\n",
        "# P(X | E) - X-Ray result given Either\n",
        "asia_net.add(DiscreteConditional((X, cardinalities[X]), make_parent_tuples([E]), \"95/5 2/98\"))\n",
        "\n",
        "# P(E | T, L) - Either Tub. or Lung Cancer (OR gate)\n",
        "# \"F T T T\" means P(E=1|T=0,L=0)=0, P(E=1|T=0,L=1)=1, P(E=1|T=1,L=0)=1, P(E=1|T=1,L=1)=1\n",
        "asia_net.add(DiscreteConditional((E, cardinalities[E]), make_parent_tuples([T, L]), \"F T T T\"))\n",
        "\n",
        "# P(B | S) - Bronchitis given Smoker\n",
        "asia_net.add(DiscreteConditional((B, cardinalities[B]), make_parent_tuples([S]), \"70/30 40/60\"))\n",
        "\n",
        "# P(L | S) - Lung Cancer given Smoker\n",
        "asia_net.add(DiscreteConditional((L, cardinalities[L]), make_parent_tuples([S]), \"99/1 90/10\"))\n",
        "\n",
        "# P(T | A) - Tuberculosis given Asia\n",
        "asia_net.add(DiscreteConditional((T, cardinalities[T]), make_parent_tuples([A]), \"99/1 95/5\"))\n",
        "\n",
        "# P(S) - Prior on Smoking\n",
        "asia_net.add(DiscreteConditional((S, cardinalities[S]), [], \"1/1\")) # or \"50/50\"\n",
        "\n",
        "# Add conditional probability tables (CPTs) using C++ sugar syntax\n",
        "# P(A) - Prior on Asia\n",
        "asia_net.add(DiscreteConditional((A, cardinalities[A]), [], \"99/1\"))\n",
        "\n",
        "print(\"Asia Bayes Net:\")\n",
        "asia_net.print()"
      ]
    },
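    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bayesnet_cpt_string_md"
      },
      "source": [
        "To make CPT strings such as `\"9/1 2/8 3/7 1/9\"` concrete, here is a small stand-alone sketch of how such a string could be turned into rows of probabilities. `parse_cpt_string` is a hypothetical helper written for this illustration, not GTSAM's parser, and it only handles the `a/b` ratios and the `F`/`T` shorthands used in this notebook. The resulting rows should agree with the leaves printed for `P( D1 | E3 B4 )` and `P( E3 | T6 L5 )` above."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bayesnet_cpt_string_code"
      },
      "outputs": [],
      "source": [
        "def parse_cpt_string(spec):\n",
        "    # Returns one normalized row [P(state 0), P(state 1), ...] per parent assignment.\n",
        "    rows = []\n",
        "    for token in spec.split():\n",
        "        if token == \"F\":\n",
        "            weights = [1.0, 0.0]  # certainly state 0\n",
        "        elif token == \"T\":\n",
        "            weights = [0.0, 1.0]  # certainly state 1\n",
        "        else:\n",
        "            weights = [float(w) for w in token.split(\"/\")]\n",
        "        total = sum(weights)\n",
        "        rows.append([w / total for w in weights])\n",
        "    return rows\n",
        "\n",
        "dyspnea_rows = parse_cpt_string(\"9/1 2/8 3/7 1/9\")  # rows for (E,B) = 00, 01, 10, 11\n",
        "either_rows = parse_cpt_string(\"F T T T\")           # rows for (T,L) = 00, 01, 10, 11\n",
        "print(dyspnea_rows)\n",
        "print(either_rows)"
      ]
    },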
    {
      "cell_type": "code",
      "execution_count": 58,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "bayesnet_discrete_viz_eval_code",
        "outputId": "56789abc-def0-1234-5678-9abcdef01234"
      },
      "outputs": [
        {
          "data": {
            "image/svg+xml": [
              "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
              "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
              " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
              "<!-- Generated by graphviz version 2.50.0 (0)\n",
              " -->\n",
              "<!-- Pages: 1 -->\n",
              "<svg width=\"189pt\" height=\"260pt\"\n",
              " viewBox=\"0.00 0.00 189.00 260.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
              "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 256)\">\n",
              "<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-256 185,-256 185,4 -4,4\"/>\n",
              "<!-- var4683743612465315848 -->\n",
              "<g id=\"node1\" class=\"node\">\n",
              "<title>var4683743612465315848</title>\n",
              "<ellipse fill=\"none\" stroke=\"black\" cx=\"27\" cy=\"-234\" rx=\"27\" ry=\"18\"/>\n",
              "<text text-anchor=\"middle\" x=\"27\" y=\"-230.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">A8</text>\n",
              "</g>\n",
              "<!-- var6052837899185946630 -->\n",
              "<g id=\"node7\" class=\"node\">\n",
              "<title>var6052837899185946630</title>\n",
              "<ellipse fill=\"none\" stroke=\"black\" cx=\"27\" cy=\"-162\" rx=\"27\" ry=\"18\"/>\n",
              "<text text-anchor=\"middle\" x=\"27\" y=\"-158.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">T6</text>\n",
              "</g>\n",
              "<!-- var4683743612465315848&#45;&gt;var6052837899185946630 -->\n",
              "<g id=\"edge1\" class=\"edge\">\n",
              "<title>var4683743612465315848&#45;&gt;var6052837899185946630</title>\n",
              "<path fill=\"none\" stroke=\"black\" d=\"M27,-215.7C27,-207.98 27,-198.71 27,-190.11\"/>\n",
              "<polygon fill=\"black\" stroke=\"black\" points=\"30.5,-190.1 27,-180.1 23.5,-190.1 30.5,-190.1\"/>\n",
              "</g>\n",
              "<!-- var4755801206503243780 -->\n",
              "<g id=\"node2\" class=\"node\">\n",
              "<title>var4755801206503243780</title>\n",
              "<ellipse fill=\"none\" stroke=\"black\" cx=\"154\" cy=\"-90\" rx=\"27\" ry=\"18\"/>\n",
              "<text text-anchor=\"middle\" x=\"154\" y=\"-86.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">B4</text>\n",
              "</g>\n",
              "<!-- var4899916394579099649 -->\n",
              "<g id=\"node3\" class=\"node\">\n",
              "<title>var4899916394579099649</title>\n",
              "<ellipse fill=\"none\" stroke=\"black\" cx=\"154\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n",
              "<text text-anchor=\"middle\" x=\"154\" y=\"-14.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">D1</text>\n",
              "</g>\n",
              "<!-- var4755801206503243780&#45;&gt;var4899916394579099649 -->\n",
              "<g id=\"edge8\" class=\"edge\">\n",
              "<title>var4755801206503243780&#45;&gt;var4899916394579099649</title>\n",
              "<path fill=\"none\" stroke=\"black\" d=\"M154,-71.7C154,-63.98 154,-54.71 154,-46.11\"/>\n",
              "<polygon fill=\"black\" stroke=\"black\" points=\"157.5,-46.1 154,-36.1 150.5,-46.1 157.5,-46.1\"/>\n",
              "</g>\n",
              "<!-- var4971973988617027587 -->\n",
              "<g id=\"node4\" class=\"node\">\n",
              "<title>var4971973988617027587</title>\n",
              "<ellipse fill=\"none\" stroke=\"black\" cx=\"82\" cy=\"-90\" rx=\"27\" ry=\"18\"/>\n",
              "<text text-anchor=\"middle\" x=\"82\" y=\"-86.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">E3</text>\n",
              "</g>\n",
              "<!-- var4971973988617027587&#45;&gt;var4899916394579099649 -->\n",
              "<g id=\"edge7\" class=\"edge\">\n",
              "<title>var4971973988617027587&#45;&gt;var4899916394579099649</title>\n",
              "<path fill=\"none\" stroke=\"black\" d=\"M96.57,-74.83C106.75,-64.94 120.52,-51.55 132.03,-40.36\"/>\n",
              "<polygon fill=\"black\" stroke=\"black\" points=\"134.47,-42.87 139.2,-33.38 129.59,-37.85 134.47,-42.87\"/>\n",
              "</g>\n",
              "<!-- var6341068275337658370 -->\n",
              "<g id=\"node8\" class=\"node\">\n",
              "<title>var6341068275337658370</title>\n",
              "<ellipse fill=\"none\" stroke=\"black\" cx=\"82\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n",
              "<text text-anchor=\"middle\" x=\"82\" y=\"-14.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">X2</text>\n",
              "</g>\n",
              "<!-- var4971973988617027587&#45;&gt;var6341068275337658370 -->\n",
              "<g id=\"edge6\" class=\"edge\">\n",
              "<title>var4971973988617027587&#45;&gt;var6341068275337658370</title>\n",
              "<path fill=\"none\" stroke=\"black\" d=\"M82,-71.7C82,-63.98 82,-54.71 82,-46.11\"/>\n",
              "<polygon fill=\"black\" stroke=\"black\" points=\"85.5,-46.1 82,-36.1 78.5,-46.1 85.5,-46.1\"/>\n",
              "</g>\n",
              "<!-- var5476377146882523141 -->\n",
              "<g id=\"node5\" class=\"node\">\n",
              "<title>var5476377146882523141</title>\n",
              "<ellipse fill=\"none\" stroke=\"black\" cx=\"99\" cy=\"-162\" rx=\"27\" ry=\"18\"/>\n",
              "<text text-anchor=\"middle\" x=\"99\" y=\"-158.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">L5</text>\n",
              "</g>\n",
              "<!-- var5476377146882523141&#45;&gt;var4971973988617027587 -->\n",
              "<g id=\"edge5\" class=\"edge\">\n",
              "<title>var5476377146882523141&#45;&gt;var4971973988617027587</title>\n",
              "<path fill=\"none\" stroke=\"black\" d=\"M94.88,-144.05C92.99,-136.26 90.7,-126.82 88.58,-118.08\"/>\n",
              "<polygon fill=\"black\" stroke=\"black\" points=\"91.96,-117.17 86.2,-108.28 85.15,-118.82 91.96,-117.17\"/>\n",
              "</g>\n",
              "<!-- var5980780305148018695 -->\n",
              "<g id=\"node6\" class=\"node\">\n",
              "<title>var5980780305148018695</title>\n",
              "<ellipse fill=\"none\" stroke=\"black\" cx=\"126\" cy=\"-234\" rx=\"27\" ry=\"18\"/>\n",
              "<text text-anchor=\"middle\" x=\"126\" y=\"-230.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">S7</text>\n",
              "</g>\n",
              "<!-- var5980780305148018695&#45;&gt;var4755801206503243780 -->\n",
              "<g id=\"edge3\" class=\"edge\">\n",
              "<title>var5980780305148018695&#45;&gt;var4755801206503243780</title>\n",
              "<path fill=\"none\" stroke=\"black\" d=\"M129.38,-215.87C134.17,-191.56 142.99,-146.82 148.67,-118.01\"/>\n",
              "<polygon fill=\"black\" stroke=\"black\" points=\"152.11,-118.68 150.61,-108.19 145.24,-117.32 152.11,-118.68\"/>\n",
              "</g>\n",
              "<!-- var5980780305148018695&#45;&gt;var5476377146882523141 -->\n",
              "<g id=\"edge2\" class=\"edge\">\n",
              "<title>var5980780305148018695&#45;&gt;var5476377146882523141</title>\n",
              "<path fill=\"none\" stroke=\"black\" d=\"M119.6,-216.41C116.49,-208.34 112.67,-198.43 109.17,-189.35\"/>\n",
              "<polygon fill=\"black\" stroke=\"black\" points=\"112.4,-188.03 105.54,-179.96 105.87,-190.55 112.4,-188.03\"/>\n",
              "</g>\n",
              "<!-- var6052837899185946630&#45;&gt;var4971973988617027587 -->\n",
              "<g id=\"edge4\" class=\"edge\">\n",
              "<title>var6052837899185946630&#45;&gt;var4971973988617027587</title>\n",
              "<path fill=\"none\" stroke=\"black\" d=\"M38.93,-145.81C46.21,-136.55 55.66,-124.52 63.85,-114.09\"/>\n",
              "<polygon fill=\"black\" stroke=\"black\" points=\"66.66,-116.18 70.09,-106.16 61.16,-111.86 66.66,-116.18\"/>\n",
              "</g>\n",
              "</g>\n",
              "</svg>\n"
            ],
            "text/plain": [
              "<graphviz.sources.Source at 0x18b782e0520>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Log Probability of all zeros: -1.2366269421045588\n",
            "Sampled Values (basic print):\n",
            "DiscreteValues{4683743612465315848: 0, 4755801206503243780: 1, 4899916394579099649: 1, 4971973988617027587: 0, 5476377146882523141: 0, 5980780305148018695: 1, 6052837899185946630: 0, 6341068275337658370: 0}\n",
            "Sampled Values (pretty print):\n",
            "  A8: 0\n",
            "  B4: 1\n",
            "  D1: 1\n",
            "  E3: 0\n",
            "  L5: 0\n",
            "  S7: 1\n",
            "  T6: 0\n",
            "  X2: 0\n"
          ]
        }
      ],
      "source": [
        "# Visualize the network structure\n",
        "dot_string = asia_net.dot()\n",
        "display(graphviz.Source(dot_string))\n",
        "\n",
        "# Evaluate the log probability of a specific assignment\n",
        "# Example: Calculate P(A=0, S=0, T=0, L=0, B=0, E=0, X=0, D=0)\n",
        "values = DiscreteValues()\n",
        "for key, card in cardinalities.items():\n",
        "    values[key] = 0 # Assign 0 to all variables to start\n",
        "\n",
        "log_prob_zeros = asia_net.logProbability(values)\n",
        "print(f\"Log Probability of all zeros: {log_prob_zeros}\")\n",
        "\n",
        "# Sample from the Bayes Net\n",
        "sample = asia_net.sample()\n",
        "print(\"Sampled Values (basic print):\")\n",
        "print(sample)\n",
        "\n",
        "# --- Pretty Print ---\n",
        "print(\"Sampled Values (pretty print):\")\n",
        "# Create a reverse mapping from integer key to string like 'A8'\n",
        "# We defined A=symbol('A',8), S=symbol('S',7), etc. above\n",
        "symbol_map = { A: 'A8', S: 'S7', T: 'T6', L: 'L5', B: 'B4', E: 'E3', X: 'X2', D: 'D1' }\n",
        "# Iterate through the sampled values and print nicely\n",
        "# Sort items by the symbol string for consistent order (optional)\n",
        "for key, value in sorted(sample.items(), key=lambda item: symbol_map.get(item[0], str(item[0]))):\n",
        "    symbol_str = symbol_map.get(key, f\"UnknownKey({key})\") # Get 'A8' from key A\n",
        "    print(f\"  {symbol_str}: {value}\")"
      ]
    }
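    ,
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bayesnet_asia_logprob_check_md"
      },
      "source": [
        "The all-zeros log-probability printed above can be checked by hand: it is the sum of the logs of the eight CPT entries selected by that assignment. The sketch below redoes this arithmetic in plain Python using the probabilities shown in the printed network. The dictionary and variable names are illustrative only."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bayesnet_asia_logprob_check_code"
      },
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "# CPT entries picked out by the all-zeros assignment, read off the printed network.\n",
        "selected = {\n",
        "    \"P(D=0|E=0,B=0)\": 0.9,\n",
        "    \"P(X=0|E=0)\": 0.95,\n",
        "    \"P(E=0|T=0,L=0)\": 1.0,\n",
        "    \"P(B=0|S=0)\": 0.7,\n",
        "    \"P(L=0|S=0)\": 0.99,\n",
        "    \"P(T=0|A=0)\": 0.99,\n",
        "    \"P(S=0)\": 0.5,\n",
        "    \"P(A=0)\": 0.99,\n",
        "}\n",
        "\n",
        "# Chain rule in log space: add up the log of each selected entry.\n",
        "log_prob_by_hand = sum(math.log(p) for p in selected.values())\n",
        "print(log_prob_by_hand)"
      ]
    }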
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "gtsam",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.13.1"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}